{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a5ec7140",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-14T02:42:20.997483Z",
     "start_time": "2024-05-14T02:42:19.666390Z"
    }
   },
   "outputs": [],
   "source": [
    "import gym\n",
    "import numpy as np\n",
    "from IPython import display\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "42dba77c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-14T02:42:25.010890Z",
     "start_time": "2024-05-14T02:42:25.005359Z"
    }
   },
   "outputs": [],
   "source": [
    "class GymHelper:\n",
    "    \"\"\"Draw frames of a gym env inline in a notebook, reusing one figure.\n",
    "\n",
    "    Assumes env.render() returns an image array (e.g. render_mode='rgb_array').\n",
    "    \"\"\"\n",
    "    def __init__(self,env,figsize=(3,3)):\n",
    "        self.env=env\n",
    "        self.figsize=figsize\n",
    "        # Keep a handle on the figure so render() always redraws this one,\n",
    "        # not whatever happens to be the current pyplot figure.\n",
    "        self.fig=plt.figure(figsize=figsize)\n",
    "        self.img=plt.imshow(env.render())\n",
    "    def render(self,title=None):\n",
    "        \"\"\"Update the image with the env's current frame and display it.\n",
    "\n",
    "        title: optional string drawn above the frame (skipped when falsy).\n",
    "        \"\"\"\n",
    "        self.img.set_data(self.env.render())\n",
    "        # Set the title BEFORE displaying so it belongs to this frame; setting it\n",
    "        # afterwards made every shown frame carry the previous call's title.\n",
    "        if title:\n",
    "            self.img.axes.set_title(title)\n",
    "        display.display(self.fig)\n",
    "        # wait=True defers clearing until the next output arrives, avoiding flicker.\n",
    "        display.clear_output(wait=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "425cb713",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-14T02:42:32.836645Z",
     "start_time": "2024-05-14T02:42:30.804844Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "from tqdm import *\n",
    "import collections\n",
    "import time\n",
    "import random\n",
    "import sys\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6eedf826",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-14T02:45:37.330467Z",
     "start_time": "2024-05-14T02:45:28.113439Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAASAAAADbCAYAAADNoUzuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAes0lEQVR4nO3dfVSUdf438Pc1MDMMTwMDMsOICvms5AMYqRniA7AqmXW3t+Vju+1v105Q3tV97tx2V9vTCdY9229Pp8yz1brtHxu2q2117vIWS0mj/W0LkmA+FvJgIKAyAwoMzHzuP8j57agoIPBl8P0653tOXNeXaz7zzXnzne91zTWaiAiIiBTQqS6AiG5fDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwg6pOdO3di6tSpMJlM0DQNK1asgKZpfTrWgQMHoGkaDhw40Kvfi4+PR1ZWVp8ek4aGQNUFkP9paGjA2rVr8YMf/ADbtm2D0WiE3W7Hc88916fjJSUl4YsvvsCUKVP6uVIa6hhA1GsnT55ER0cH1qxZg/nz53u3jx49uk/HCw8Px+zZs/urPPIjfAtGvfLoo49i3rx5AICVK1dC0zSkpaVhy5Yt17wFu/IWac+ePUhKSoLJZMKkSZPwxz/+0aff9d6Cffvtt3j44Ydht9thNBphtVqxaNEilJaWXlPTzY5PQxdnQNQrv/zlL5GSkoInnngCL730EhYsWIDw8HC8++671+3/1Vdf4ZlnnsFzzz0Hq9WKN998E4899hjGjRuH1NTUbh9n6dKlcLvd2Lp1K0aPHo3GxkYUFRWhqampX45PQwMDiHpl7Nix3rWa8ePH3/StU2NjIz7//HPv27PU1FR88skn+Mtf/tJtQJw/fx4nTpzA73//e6xZs8a7/cEHH+yX49PQwQCiATVjxgyftaGgoCBMmDABlZWV3f6OxWLB2LFj8dvf/hZutxsLFizA9OnTodNdu2LQl+PT0ME1IBpQUVFR12wzGo1obW3t9nc0TcMnn3yCzMxMbN26FUlJSRgxYgSefPJJNDc33/LxaejgDIiGpDFjxuCtt94C0HXW7d1338WWLVvgcrmwfft2xdVRf+EMiIa8CRMm4Be/+AXuvPNOlJSUqC6H+hFnQDTkHDlyBNnZ2fjhD3+I8ePHw2Aw4NNPP8WRI0f6fLEjDU0MIBpybDYbxo4di23btqG6uhqapuGOO+7A7373O+Tk5Kguj/qRxm/FICJVuAZERMowgIhIGQYQESmjNIC2bduGhIQEBAUFITk5GQcPHlRZDhENMmUBtHPnTmzcuBHPP/88Dh8+jHvvvRdLlixBVVWVqpKIaJApOwt29913IykpCa+//rp32+TJk7FixQrk5uaqKImIBpmS64BcLheKi4uvuagsIyMDRUVF1/Rvb29He3u792ePx4MLFy4gKiqqz7cBJaKBIyJobm6G3W6/7oeIr1ASQI2NjXC73bBarT7brVYr6urqrumfm5uLF154YbDKI6J+Ul1djbi4uG73K12Evnr2IiLXndFs2rQJDofD27hOROQfwsLCbrhfyQwoOjoaAQEB18x26uvrr5kVAV23VzAajYNVHhH1k5stkSiZARkMBiQnJ6OgoMBne0FBAebOnauiJCJSQNmHUZ9++mmsXbsWs2bNwpw5c/CHP/wBVVVV2LBhg6qSiGiQKQuglStX4vz58/j1r3+N2tpaJCYm4qOPPsKYMWNUlUREg8wvPw3vdDphNptVl0FEN+FwOBAeHt7tfn4WjIiUYQARkTIMICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMr0OoA+++wz3HfffbDb7dA0DX//
+9999osItmzZArvdDpPJhLS0NBw9etSnT3t7O3JychAdHY2QkBAsX74cNTU1t/REiMj/9DqALl26hOnTp+PVV1+97v6tW7fi5Zdfxquvvoovv/wSNpsN6enpaG5u9vbZuHEj3nvvPeTn5+PQoUNoaWlBVlYW3G53358JEfkfuQUA5L333vP+7PF4xGazSV5enndbW1ubmM1m2b59u4iINDU1iV6vl/z8fG+fs2fPik6nkz179vTocR0OhwBgY2Mb4s3hcNzwtdyva0AVFRWoq6tDRkaGd5vRaMT8+fNRVFQEACguLkZHR4dPH7vdjsTERG+fq7W3t8PpdPo0IvJ//RpAdXV1AACr1eqz3Wq1evfV1dXBYDAgMjKy2z5Xy83Nhdls9rZRo0b1Z9lEpMiAnAXTNM3nZxG5ZtvVbtRn06ZNcDgc3lZdXd1vtRKROv0aQDabDQCumcnU19d7Z0U2mw0ulwsXL17sts/VjEYjwsPDfRoR+b9+DaCEhATYbDYUFBR4t7lcLhQWFmLu3LkAgOTkZOj1ep8+tbW1KC8v9/YhottDYG9/oaWlBadPn/b+XFFRgdLSUlgsFowePRobN27ESy+9hPHjx2P8+PF46aWXEBwcjFWrVgEAzGYzHnvsMTzzzDOIioqCxWLBs88+izvvvBOLFy/uv2dGRENfj8+5f2///v3XPd22fv16Eek6Fb9582ax2WxiNBolNTVVysrKfI7R2toq2dnZYrFYxGQySVZWllRVVfW4Bp6GZ2Pzj3az0/CaiAj8jNPphNlsVl0GEd2Ew+G44ZotPwtGRMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMowgIhIGQYQESnDACIiZXoVQLm5ubjrrrsQFhaGmJgYrFixAidOnPDpIyLYsmUL7HY7TCYT0tLScPToUZ8+7e3tyMnJQXR0NEJCQrB8+XLU1NTc+rMhIv8ivZCZmSk7duyQ8vJyKS0tlWXLlsno0aOlpaXF2ycvL0/CwsJk165dUlZWJitXrpTY2FhxOp3ePhs2bJCRI0dKQUGBlJSUyIIFC2T69OnS2dnZozocDocAYGNjG+LN4XDc8LXcqwC6Wn19vQCQwsJCERHxeDxis9kkLy/P26etrU3MZrNs375dRESamppEr9dLfn6+t8/Zs2dFp9PJnj17evS4DCA2Nv9oNwugW1oDcjgcAACLxQIAqKioQF1dHTIyMrx9jEYj5s+fj6KiIgBAcXExOjo6fPrY7XYkJiZ6+1ytvb0dTqfTpxGR/+tzAIkInn76acybNw+JiYkAgLq6OgCA1Wr16Wu1Wr376urqYDAYEBkZ2W2fq+Xm5sJsNnvbqFGj+lo2EQ0hfQ6g7OxsHDlyBO+88841+zRN8/lZRK7ZdrUb9dm0aRMcDoe3VVdX97VsIhpC+hRAOTk5+OCDD7B//37ExcV5t9tsNgC4ZiZTX1/vnRXZbDa4XC5cvHix2z5XMxqNCA8P92lE5P96FUAiguzsbOzevRuffvopEhISfPYnJCTAZrOhoKDAu83lcqGwsBBz584FACQnJ0Ov1/v0qa2tRXl5ubcPEd0menPW6/HHHxez2SwHDhyQ2tpab7t8+bK3T15enpjNZtm9e7eUlZXJI488ct3T8HFxcbJv3z4pKSmRhQsX8jQ8G9swbP16Gr67B9mxY4e3j8fjkc2bN4vNZhOj0SipqalSVlbmc5zW1lbJzs4Wi8UiJpNJsrKypKqqqsd1MIDY2Pyj3SyAtO+Dxa84nU6YzWbVZRDRTTgcjhuu2fKzYESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTKBqgsg/6BpwIMPAhERQFER0NwM1NYCbrfq
ym7OagUeeACoqwNKS7tqP39edVUEMICohzQNuP9+YPJk4LHHAJcLOHkScDiAzz4DmpqAsrKu7Zcvq67W14gRwI9+1PXfnZ1d4VNdDXzzDXD0KFBVBdTUAG1tQEeH2lpvNwwg6hVNAwyGrpaU1LUtLa1rJnT+PHDuHPD8812zo6FE0wCdDggIAOz2rpaS0rXv0qWuWVFBAfDqq4DHo7bW2wkDiHpNpOtF2tTU9eItLgYuXAA+/7zrhXzunOoKr+/KjWfa2oCWlq5Z0OnTwIkTXbO5+nqGz2BjAFGPXbhgw6FDdd63XF99BbS3d4XQQN5VKigoDB0dbXC7+/b+SKcLw7FjgpMnW1Bc3BWQp093vV1sb+/nYqlXGEDUI5oWgMOHM/H2228P6uOagsxYPPc5tLQ2oLj8HTibe//eTq+fgJdf7sRXX301ABXSrWAA0ZCl15uQlvK/MOOOVdCgYdzIBfj8q9dw6swBdHS0qS6P+gGvA6IhKTAwCKl35SB5/DoE6PTQ6QJhj0zC8nn/iR/cuwXRUXdA0/jP19/x/yANCZoWAKDriykN+mCkpWzE7Mk/RYDO+G99NATpzUgauwYr09/EjKn/AyZThJqCqV8wgEg5kykCdyetR2REHPSBJqTe9STumvQoAnWm635brqbpMCJ8EpakvITMe36JOPsMBAToFVROt4prQKSMpukw0jYNaXc9g1HRKTAaQtBw/htMH/8/YQgIu+nXeRsCQzA94WHEWqaj+OSfUfr13+ByDbGrIOmGOAMiJQyGEKRMX4eHFr2O2Mhp+PLEDpyuOIg5036KUEPMTcPnCk3TEGOejDGxsyHCi3h6w2AwwGAwQKdTFwOcAdGgi4wcjcWzN2Fc7EKcbz6Nz0p+hW8qDyJ93vOwmRNxZS2opzo9bTh+5v/xzFgPhISEYPz48UhPT0dmZiaMRiOOHz+OEydO4Pjx4/jmm2/Q2NiIpqYmdAzC51IYQDToQk0jEGVOQPmZ93CoZBsuNlXBFBQBvSEInZ42BOpMPT6WiOB8yzeoazgKS+RoeMSDpqaafq9Zp9PB44eXSWuahrCwMEybNg3p6elIT0/H1KlTERoa6p353HPPPRARuN1utLW14dy5c6ipqcGpU6dQXl6OU6dO4cyZM6itrUVraytcLhf66/tMGUA06GpqS/F/P/s56hqOoaOjFQDQ2taEfQd/g3MTj2PGxB8iOmwiAjRDD96KCU7XfIJRtllISXwUp2o+xSeHtvZrvcFBQZgUH4/jZ87gctvQn2VpmgaLxYKZM2diyZIlWLRoEcaNG4fg4OBuFvU1aJoGnU4HvV6PsLAwjBs3DmlpaRARdHR0oKWlBRcvXsS3336LU6dO4fjx4zh27BiqqqrQ0NAAp9MJj8fT62BiANGgE3Gj+ruSa7Zfam3Ef321A5Vn/wvJiaswYVQGQo1W6LSAbo91ueMCjlfshckQgWBjNCxhYxAQYIDb7eq3el0dHWhsaoJrCH9UPiAgAFarFSkpKVi6dClSU1ORkJAAvV7f4/W069E0DQaDARaLBRaLBWPHjsXixYshIvB4PGhubsa5c+dQXV2NY8eO4dixYzh9+jROnz6NM2fO3PT4vVp9ev311zFt2jSEh4cjPDwcc+bMwccff+zdLyLYsmUL7HY7TCYT0tLScPToUZ9jtLe3IycnB9HR0QgJCcHy5ctRU9P/U2byTyIe1DaUY++hF7H3ixdwpuEQ2jubr/uXVURQc/5fqG84gSbnWbR3OBEVPg5BxrB+ranT7UZVXR06h9jNjwIDAxEfH49Vq1bhz3/+M4qKirBz50785Cc/wcSJE2Ew9GQG2XtXZkuBgYGIjIzEpEmTkJ6ejieffBKvvfYaPvjgAxw4cKBnz6E3DxwXF4e8vDyMGzcOAPD222/j/vvvx+HDhzF16lRs3boVL7/8Mv70pz9hwoQJePHFF5Geno4TJ04gLKzrH8XGjRvx4YcfIj8/H1FRUXjmmWeQlZWF4uJi
BAR0/5eObi+ujssoP/UhqusOY+aUHyJx7P2IDE5AgO6/r/dp63Tg8LF8uDou43LrBTgvf4fo8AkwGIJx6fLA3XHMGhoKEUH9pUsD9hjdMRqNiI+Px4IFC7BkyRKkpKQgJibG+zZKNZ1OB5PJhMjIyJ79gtyiyMhIefPNN8Xj8YjNZpO8vDzvvra2NjGbzbJ9+3YREWlqahK9Xi/5+fnePmfPnhWdTid79uzp8WM6HA4BwDaILSAgQNavX39LxxhnNkuMyeT9WQMk+8475ZEJE0S7we/pdHqJHzlHHvrBq/K/Hz0im/+jRn71H9Vy/+Lfij7Q9H2fQHkg42X5Pz8ql6kTlvn8fnJyskyfPr3fxmJ+fLzMGz160MY+ODhYZs6cKc8++6zs27dPGhsbpbOz81ZfugPqymvU4XDcsF+f14Dcbjf++te/4tKlS5gzZw4qKipQV1eHjIwMbx+j0Yj58+ejqKgIP/vZz1BcXIyOjg6fPna7HYmJiSgqKkJmZmZfy6EhTq/T4RfJyTh28SJ+c/gwgK6p/FizGQ9aLNhz5gwuuq6/buPxdODM2S9wrvEYpk7IwsxJKxEYEIQvSt9ER2fr933caLh4CmPtC2Aw9vwsWl8UVVVBBvQRutZ0JkyYgAcffBBZWVmYPHkywsPDAWBIzHT6S68DqKysDHPmzEFbWxtCQ0Px3nvvYcqUKSgqKgIAWK1Wn/5WqxWVlZUAgLq6OhgMhmumZ1arFXV1dd0+Znt7O9r/7cYtTqezt2XTLdDpdIiLi0NERATWrVvX5+OcDAlBW2cn1t15p3fb1wEBuGAy4b6HH+7RMTToEG75ChERI7HEOAcid3v3xVoTEGM5g3tCEnDnzP+uMzQ0FKdOnUJoaChaWlr6XP8VHQN4Oj4sLAz33HMP1q5di/T0dERFRSm9UHCg9TqAJk6ciNLSUjQ1NWHXrl1Yv349CgsLvfuvTmcRuWli36xPbm4uXnjhhd6WSrfIaDRixowZWL16NZYvX464uLghtk73aDfbV/j8JCJoaWnBkSNHsHv3bnz00Uc4ffo0Ojs7B7rAHtE0DXFxcVi+fDlWr16NGTNmwGQa2FnckHGr7/UWLVokP/3pT+Wbb74RAFJSUuKzf/ny5bJu3ToREfnkk08EgFy4cMGnz7Rp0+RXv/pVt4/R1tYmDofD26qrq5WviQznNmLECFm1apV8/PHH4nQ6xePx3Oo/kyHB4/GI2+2WhoYGef/992XNmjUycuRI0el0SsY5KChIZs+eLa+88opUVlaK2+0eNmPd0zWgWw6ghQsXyvr1672L0L/5zW+8+9rb26+7CL1z505vn++++46L0EOgBQYGypQpU2TLli1y9OhRcblct/pPY0jzeDzicrmkoqJC3njjDcnMzJSIiAjRNG1Ax1nTNImOjvYGvMPhGDah8+8GJIA2bdokn332mVRUVMiRI0fk5z//ueh0Otm7d6+IiOTl5YnZbJbdu3dLWVmZPPLIIxIbGytOp9N7jA0bNkhcXJzs27dPSkpKZOHChTJ9+vRereozgPqvhYWFyZIlS+Sdd96RhoYG8Xg8w/IFcSMej0fa2trkyJEj8uKLL8qsWbPE9G9n6/qjBQQEyNSpU+WFF16QY8eOSUdHx7Ae5wEJoB//+McyZswYMRgMMmLECFm0aJE3fES6/kdu3rxZbDabGI1GSU1NlbKyMp9jtLa2SnZ2tlgsFjGZTJKVlSVVVVW9KYMBdItNp9PJmDFj5Mknn5R//OMf0traOqxfDL3hdrvF6XTK/v37JScnR8aNGyd6vf6WAn7p0qXegHe73aqf4qDoaQBpIgP5fQYDw+l0wmw2qy7D7wQFBWHmzJlYs2YNsrKyEBcXN6zPsNwqj8eDhoYGHDx4EH/7299QWFiIc+fO3fTzTpqmYdSoUVixYgVWr16NadOmISgoaJCqHhquvEYdDof38oHrYQANc5qmYcSIEcjIyMC6deswe/ZshIaGDqtrSQaafP9J
8crKSuzduxe7du3Cv/71LzgcDp9+JpPJG/D33XcfYmNjh9hZw8HDALrNBQYGYuLEiVi5ciUeeughjBs3Dno9b1t6q0QEra2tOHbsGN5//318+OGHqK2txeLFi7F27VrMnTuXAQ8G0G3LbDbj3nvvxbp167Bw4UJYLJbb/sUwEK68bJxOJ5qamjBy5EgEBARwrL/X0wDi7TiGgYCAAIwePRorVqzAqlWrkJiYCKPRyBfDALoytmazmX8MbwEDyI8FBQUhOTkZa9euxbJly2C324fMp6KJeoIB5Gc0TUNMTAwyMzOxbt063H333QgJCWHokF9iAPmJwMBATJ48GQ8//DAeeugh793uiPwZA2iIM5vNSE1NxaOPPoq0tDRERkZytkPDBgPoe1duM6nT6aBpGjweD9xud7/d/b83dDodxowZgwceeACrVq3C1KlTuahMw9JtE0B6vR4mk8nboqKiYLVaYbVaERMTA5vN5m3h4eGorKzE8ePHcfLkSZw8eRLV1dVwOBxobr7+/Yn7g8lkwqxZs7yLyjabjYvKNKz5dQAFBAQgMDAQgYGBCA4ORmRkJCwWCyIjIxETE4PY2FjY7XbExsYiKioKkZGRiIyMREREBAwGAwICAqDT6a57/UZSUpL3CtjOzk40NzejpqYGVVVVOH78OL7++mtUVFSgqqoKjY2NcLlcffoiN51OhxEjRmDp0qVYu3YtUlJSuv36FKLhxq8vRHzrrbcQHx+PmJgYREREICgoyNsG+qKwK1fENjc3o76+HqdOncKpU6e83zD53Xff4cKFC2hpabnujEmv12PKlCneReX4+HgEBvr13wMir9viSuibPbnBJt9/V5Lb7cbFixdx7tw5VFZWer8vqaKiAtXV1Zg2bRrWr1+P+fPnIyIigrMdGnYYQEOMx+NBW1sbWltbERoaOmDf2UQ0FPCjGEOMTqdDcHAwgoODVZdCNGTwZjBEpAwDiIiUYQARkTIMICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImX88oZkV27i6HQ6FVdCRNdz5bV5sxuu+mUANTc3AwBGjRqluBIiupHm5maYzeZu9/vlPaE9Hg9OnDiBKVOmoLq62m/uC62a0+nEqFGjOGa9xHHrPRFBc3Mz7HY7dLruV3r8cgak0+kwcuRIAEB4eDj/UfQSx6xvOG69c6OZzxVchCYiZRhARKSM3waQ0WjE5s2bYTQaVZfiNzhmfcNxGzh+uQhNRMOD386AiMj/MYCISBkGEBEpwwAiImX8MoC2bduGhIQEBAUFITk5GQcPHlRdkjK5ubm46667EBYWhpiYGKxYsQInTpzw6SMi2LJlC+x2O0wmE9LS0nD06FGfPu3t7cjJyUF0dDRCQkKwfPly1NTUDOZTUSY3NxeapmHjxo3ebRyzQSJ+Jj8/X/R6vbzxxhvy9ddfy1NPPSUhISFSWVmpujQlMjMzZceOHVJeXi6lpaWybNkyGT16tLS0tHj75OXlSVhYmOzatUvKyspk5cqVEhsbK06n09tnw4YNMnLkSCkoKJCSkhJZsGCBTJ8+XTo7O1U8rUHzz3/+U+Lj42XatGny1FNPebdzzAaH3wVQSkqKbNiwwWfbpEmT5LnnnlNU0dBSX18vAKSwsFBERDwej9hsNsnLy/P2aWtrE7PZLNu3bxcRkaamJtHr9ZKfn+/tc/bsWdHpdLJnz57BfQKDqLm5WcaPHy8FBQUyf/58bwBxzAaPX70Fc7lcKC4uRkZGhs/2jIwMFBUVKapqaHE4HAAAi8UCAKioqEBdXZ3PmBmNRsyfP987ZsXFxejo6PDpY7fbkZiYOKzH9YknnsCyZcuwePFin+0cs8HjVx9GbWxshNvthtVq9dlutVpRV1enqKqhQ0Tw9NNPY968eUhMTAQA77hcb8wqKyu9fQwGAyIjI6/pM1zHNT8/HyUlJfjyyy+v2ccxGzx+
FUBXaJrm87OIXLPtdpSdnY0jR47g0KFD1+zry5gN13Gtrq7GU089hb179yIoKKjbfhyzgedXb8Gio6MREBBwzV+Y+vr6a/5a3W5ycnLwwQcfYP/+/YiLi/Nut9lsAHDDMbPZbHC5XLh48WK3fYaT4uJi1NfXIzk5GYGBgQgMDERhYSFeeeUVBAYGep8zx2zg+VUAGQwGJCcno6CgwGd7QUEB5s6dq6gqtUQE2dnZ2L17Nz799FMkJCT47E9ISIDNZvMZM5fLhcLCQu+YJScnQ6/X+/Spra1FeXn5sBzXRYsWoaysDKWlpd42a9YsrF69GqWlpbjjjjs4ZoNF4QJ4n1w5Df/WW2/J119/LRs3bpSQkBA5c+aM6tKUePzxx8VsNsuBAwektrbW2y5fvuztk5eXJ2azWXbv3i1lZWXyyCOPXPeUclxcnOzbt09KSkpk4cKFt9Up5X8/CybCMRssfhdAIiKvvfaajBkzRgwGgyQlJXlPOd+OAFy37dixw9vH4/HI5s2bxWazidFolNTUVCkrK/M5Tmtrq2RnZ4vFYhGTySRZWVlSVVU1yM9GnasDiGM2OHg7DiJSxq/WgIhoeGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlKGAUREyvx/AbunDUucY7cAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# LunarLanderContinuous-v2 is the continuous-action variant; passing\n",
    "# continuous=False overrides it to a discrete action space, which a DDPG\n",
    "# actor (tanh outputs) cannot drive, so the override is not passed.\n",
    "env=gym.make(\"LunarLanderContinuous-v2\",render_mode=\"rgb_array\")\n",
    "env.reset()\n",
    "gym_helper=GymHelper(env)\n",
    "# Roll out at most 100 random actions, redrawing the frame before each step.\n",
    "for i in range(100):\n",
    "    gym_helper.render(title=str(i))\n",
    "    action=env.action_space.sample()\n",
    "    observation,reward,terminated,truncated,info=env.step(action)\n",
    "    if terminated or truncated:\n",
    "        break\n",
    "gym_helper.render(\"finish\")\n",
    "env.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08f45fae",
   "metadata": {},
   "source": [
    "### DDPG (Deep Deterministic Policy Gradient)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "be3e03b2",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-14T03:03:49.399112Z",
     "start_time": "2024-05-14T03:03:49.386608Z"
    }
   },
   "outputs": [],
   "source": [
    "# DDPG actor (policy) and critic (Q-value) networks.\n",
    "class ActorNetwork(nn.Module):\n",
    "    \"\"\"Deterministic policy network: input features -> actions squashed by tanh.\n",
    "\n",
    "    input_dim is the size of the input (presumably the state) and output_dim\n",
    "    the number of action components. Owns its own Adam optimizer.\n",
    "    \"\"\"\n",
    "    def __init__(self,input_dim,output_dim):\n",
    "        super(ActorNetwork,self).__init__()\n",
    "        self.fc=nn.Sequential(\n",
    "            nn.Linear(input_dim,512),\n",
    "            nn.LayerNorm(512),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(512,256),\n",
    "            nn.LayerNorm(256),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(256,output_dim)\n",
    "        )\n",
    "        self.optimizer=optim.Adam(self.parameters(),lr=0.00005)\n",
    "        # weight_init is defined later in this cell; the name is looked up at call time.\n",
    "        self.apply(weight_init)\n",
    "    def forward(self,x):\n",
    "        # Use the argument `x` (the previous body referenced an undefined `state`).\n",
    "        # tanh bounds every action component to (-1, 1).\n",
    "        action=torch.tanh(self.fc(x))\n",
    "        return action\n",
    "class CriticNetwork(nn.Module):\n",
    "    \"\"\"Q-network: scores a (state, action) pair.\n",
    "\n",
    "    input_dim is the size of the state fed to fc1 and output_dim the size of the\n",
    "    action fed to fc2. forward returns a tensor whose last dimension is 1.\n",
    "    \"\"\"\n",
    "    def __init__(self,input_dim,output_dim):\n",
    "        super(CriticNetwork,self).__init__()\n",
    "        # State pathway.\n",
    "        self.fc1=nn.Sequential(\n",
    "            nn.Linear(input_dim,512),\n",
    "            nn.LayerNorm(512),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(512,256),\n",
    "            nn.LayerNorm(256),\n",
    "        )\n",
    "        # Action pathway, projected to the same width (256) so it can be summed\n",
    "        # with the state features.\n",
    "        self.fc2=nn.Linear(output_dim,256)\n",
    "        self.q=nn.Linear(256,1)\n",
    "        # NOTE(review): unlike ActorNetwork, weight_init is not applied here, so\n",
    "        # these layers keep PyTorch's default initialisation -- confirm intended.\n",
    "        self.optimizer=optim.Adam(self.parameters(),lr=0.00005,weight_decay=0.001)\n",
    "    def forward(self,state,action):\n",
    "        # The state pathway is self.fc1 (there is no self.fc on this class).\n",
    "        x_s=self.fc1(state)\n",
    "        x_a=self.fc2(action)\n",
    "        x=torch.relu(x_s+x_a)\n",
    "        q=self.q(x)\n",
    "        return q\n",
    "def weight_init(m):\n",
    "    \"\"\"Initialise one module in place; meant to be passed to nn.Module.apply.\n",
    "\n",
    "    Linear layers get Xavier-normal weights and zero bias. BatchNorm1d/LayerNorm\n",
    "    layers get unit weight and zero bias (parameters that are None are skipped).\n",
    "    \"\"\"\n",
    "    if isinstance(m,nn.Linear):\n",
    "        nn.init.xavier_normal_(m.weight)\n",
    "        if m.bias is not None:\n",
    "            nn.init.constant_(m.bias,0.0)\n",
    "    # Sibling of the Linear check (not nested under the bias check), and covers\n",
    "    # LayerNorm, which is the normalisation ActorNetwork/CriticNetwork use.\n",
    "    elif isinstance(m,(nn.BatchNorm1d,nn.LayerNorm)):\n",
    "        # weight/bias are None when affine/elementwise_affine is disabled.\n",
    "        if m.weight is not None:\n",
    "            nn.init.constant_(m.weight,1.0)\n",
    "        if m.bias is not None:\n",
    "            nn.init.constant_(m.bias,0.0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0276a121",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
